package scala.meta.internal.pc

import scala.jdk.CollectionConverters._

import com.sun.source.tree.CompilationUnitTree
import com.sun.source.util.JavacTask
import com.sun.source.util.Trees
import org.eclipse.lsp4j.Position
import org.eclipse.lsp4j.Range
import org.eclipse.lsp4j.TextEdit

object AutoImports {

  /**
   * Edits that make a class reference usable in a Java compilation unit.
   *
   * @param identifierEdit  replacement of the referencing identifier; set when
   *                        the simple name is already bound by a different
   *                        single-type import, so the fully-qualified name has
   *                        to be written out instead of importing
   * @param importTextEdits `import` statements to insert into the file
   */
  case class AutoImportEdits(
      identifierEdit: Option[TextEdit],
      importTextEdits: List[TextEdit]
  ) {

    /** True when neither an identifier rewrite nor an import is needed. */
    def isEmpty: Boolean = identifierEdit.isEmpty && importTextEdits.isEmpty
  }

  /**
   * Finds the position (always at column 0) where `import newImportName;`
   * should be inserted.
   *
   *  - With existing imports: directly after the last import of the leading
   *    run whose qualified name sorts before `newImportName`, or above the
   *    first import when none does, so an already sorted list stays sorted.
   *    Static imports take part in this ordering only when they are the sole
   *    imports, which keeps the type-import and static-import groups apart.
   *  - Without imports but with a package clause: below the package clause,
   *    after the blank separator line if there is one, otherwise on the very
   *    next line.
   *  - Otherwise: at the start of the file.
   *
   * @param compiler      converts source offsets to line/column positions
   * @param task          javac task that produced `root`
   * @param root          compilation unit being edited
   * @param newImportName fully-qualified name that is going to be imported
   */
  def autoImportPosition(
      compiler: JavaMetalsGlobal,
      task: JavacTask,
      root: CompilationUnitTree,
      newImportName: String
  ): Position = {
    val sourcePositions = Trees.instance(task).getSourcePositions()
    val text = root.getSourceFile.getCharContent(true).toString

    def lineAt(offset: Long): Int =
      compiler.offsetToPosition(offset.toInt, text).getLine

    val allImports = root.getImports.asScala.toList
    val typeImports = allImports.filterNot(_.isStatic)
    val imports = if (typeImports.nonEmpty) typeImports else allImports

    if (imports.nonEmpty) {
      val importsBefore = imports.takeWhile { imp =>
        imp.getQualifiedIdentifier.toString < newImportName
      }
      val line = importsBefore.lastOption match {
        case Some(previous) =>
          lineAt(sourcePositions.getEndPosition(root, previous)) + 1
        case None =>
          // Nothing sorts before the new import, so it goes above the first
          // one (`imports` is non-empty in this branch).
          lineAt(sourcePositions.getStartPosition(root, imports.head))
      }
      new Position(line, 0)
    } else {
      Option(root.getPackageName) match {
        case Some(packageName) =>
          val packageLine =
            lineAt(sourcePositions.getEndPosition(root, packageName))
          // Jumping over the line after the package clause is only safe when
          // that line is blank (or absent). Otherwise it holds other code,
          // e.g. `class A {`, and skipping it could place the import inside
          // a declaration.
          val separatorIsBlank = text.linesIterator
            .drop(packageLine + 1)
            .nextOption()
            .forall(_.trim.isEmpty)
          new Position(packageLine + (if (separatorIsBlank) 2 else 1), 0)
        case None =>
          new Position(0, 0)
      }
    }
  }

  /**
   * Computes the edits required to reference `className` at `identifierRange`.
   *
   *  - If another single-type import already binds the same simple name, the
   *    identifier is replaced by the fully-qualified `className` and nothing
   *    is imported.
   *  - No edits are produced when `className` is already imported, is covered
   *    by a wildcard import, or belongs to the compilation unit's own package.
   *  - Otherwise a single `import className;` line is inserted at the position
   *    computed by `autoImportPosition`.
   *
   * @param className       fully-qualified name of the referenced class
   * @param identifierRange range of the identifier that references it
   */
  def computeAutoImportEdits(
      compiler: JavaMetalsGlobal,
      task: JavacTask,
      root: CompilationUnitTree,
      className: String,
      identifierRange: Range
  ): AutoImportEdits = {
    val simpleName = className.split('.').lastOption.getOrElse(className)
    val imported = existingSingleImports(root, simpleName)
    val conflictingImports = imported - className
    val noEdits = AutoImportEdits(None, Nil)

    if (conflictingImports.nonEmpty) {
      AutoImportEdits(Some(new TextEdit(identifierRange, className)), Nil)
    } else if (imported.contains(className)) {
      noEdits
    } else if (
      isCoveredByWildcardImport(className, getWildcardImportPackages(root))
    ) {
      noEdits
    } else if (isInCurrentPackage(root, className)) {
      // Same-package types are in scope without an import (and default-package
      // types cannot be imported at all).
      noEdits
    } else {
      val pos = autoImportPosition(compiler, task, root, className)
      val edit = new TextEdit(new Range(pos, pos), s"import $className;\n")
      AutoImportEdits(None, List(edit))
    }
  }

  /**
   * Checks whether a fully-qualified class name is covered by a wildcard import.
   * Example: className "java.util.List" is covered by wildcard import "java.util.*".
   *
   * @param wildcardImported package/type prefixes of wildcard imports, without ".*"
   * @return false for names without any '.'
   */
  def isCoveredByWildcardImport(
      className: String,
      wildcardImported: Set[String]
  ): Boolean =
    className.contains('.') && wildcardImported.contains(packageOf(className))

  /**
   * Extracts the prefixes of all wildcard imports (static ones included).
   * Example: "import java.util.*;" yields Set("java.util").
   */
  private def getWildcardImportPackages(
      root: CompilationUnitTree
  ): Set[String] =
    root.getImports.asScala.iterator
      .map(_.getQualifiedIdentifier.toString)
      .collect { case name if name.endsWith(".*") => name.stripSuffix(".*") }
      .toSet

  /**
   * Extracts the fully-qualified names of all non-wildcard imports whose
   * simple name equals `simpleName`.
   * Example: "import java.util.List;" with simpleName "List" yields
   * Set("java.util.List").
   */
  def existingSingleImports(
      root: CompilationUnitTree,
      simpleName: String
  ): Set[String] =
    root.getImports.asScala.iterator
      .map(_.getQualifiedIdentifier.toString)
      .filter(q => !q.endsWith(".*") && simpleNameOf(q) == simpleName)
      .toSet

  /**
   * True when the package part of `className` equals the package of `root`
   * (the empty string for the default package). Nested names such as
   * `pkg.Outer.Inner` have package part `pkg.Outer` and therefore do not match.
   */
  private def isInCurrentPackage(
      root: CompilationUnitTree,
      className: String
  ): Boolean = {
    val currentPackage = Option(root.getPackageName).fold("")(_.toString)
    packageOf(className) == currentPackage
  }

  /** Everything before the last '.', or "" when there is no '.'. */
  private def packageOf(qualifiedName: String): String = {
    val lastDot = qualifiedName.lastIndexOf('.')
    if (lastDot < 0) "" else qualifiedName.substring(0, lastDot)
  }

  /** Everything after the last '.', or the whole name when there is no '.'. */
  private def simpleNameOf(qualifiedName: String): String =
    qualifiedName.substring(qualifiedName.lastIndexOf('.') + 1)
}
